import subprocess, os
import time
import getpass
import yaml
import sys
import csv

# Make the project's utility modules importable. The path is relative, so it is
# resolved against the current working directory at import time (not against
# this file's location). The script presumably expects to be launched from one
# level below the repository root — TODO confirm.
utils_path = os.path.join('..', 'lib', 'utilities')
sys.path.append(utils_path)

# This import must stay after the sys.path tweak above; load_configurations is
# presumably found via ../lib/utilities. open_yaml appears to return nested
# mappings (it is indexed that way below) — verify against the helper.
# The config path is also relative to the current working directory.
import load_configurations
config_user_dic = load_configurations.open_yaml("../config_user.yaml")

# Paths to the shared directories used by the pipeline.
# shared_simulations: passed to EstimateUnified as its data path.
# shared_estimations: scanned for existing .mat outputs and passed to
#   EstimateUnified as its output path.
shared_simulations = str(config_user_dic["external"]["shared_simulations"])
shared_estimations = str(config_user_dic["external"]["shared_estimations"])
# Truthy flag: when set, main() submits SLURM jobs for the missing outputs;
# otherwise it only reports them.
missing_run     = config_user_dic["local"]["missing_run"]

def generate_parameter_sets():
    """Enumerate every estimation configuration that should have an output file.

    Returns:
        A list of 7-tuples ``(price_iv, model, suffix, seed, market, par_rest,
        version)``. Seeds 1..1100 are generated for each (version, price_iv,
        model) combination below and for each market size. ``par_rest`` is
        ``[2, 3]`` for the ``'RCL'`` model and ``[]`` otherwise. ``suffix`` is
        always the empty string. Ordering is market, then combination (in
        listed order), then seed.
    """
    # (version suffix, price instrument, model) combinations to estimate.
    combos = (
        ("", "mw", "L"),
        ("", "price", "L"),
        ("", "sg", "L"),
        ("", "coarse", "L"),
        ("", "prod", "L"),
        ("", "mw", "L_nm"),
        ("", "sg", "L_nm"),
        ("", "coarse", "L_nm"),
        ("", "mw", "RCL"),
        ("", "sg", "RCL"),
        ("_sh", "mw", "L"),
        ("_sh", "sg", "L"),
        ("_sh", "prod", "L"),
    )
    seeds = range(1, 1101)
    markets = (10000,)
    suffix = ''

    collected = []
    for market_size in markets:
        for version, price_iv, model in combos:
            restrictions = [[2, 3]] if model == 'RCL' else [[]]
            for restriction in restrictions:
                collected.extend(
                    (price_iv, model, suffix, s, market_size, restriction, version)
                    for s in seeds
                )
    return collected

def generate_filenames(parameters):
    """Return the expected estimation output filename for one parameter set.

    Seeds are grouped into blocks of 100. ``k`` is the 1-based block index
    and ``s`` is the 1-based position within the block, so seed 1 -> k1_s1,
    seed 100 -> k1_s100, and seed 101 -> k2_s1. The ``par_rest`` entries are
    concatenated directly after the model name (e.g. ``RCL23``).

    Args:
        parameters: A ``(price_iv, model, suffix, seed, market, par_rest,
            version)`` tuple; ``suffix`` is not used in the name.
    """
    price_iv, model, suffix, seed, market, par_rest, version = parameters
    block_size = 100
    # int() of a true division truncates toward zero (kept deliberately).
    flavor = int((seed - 1) / block_size) + 1
    iteration = seed - block_size * (flavor - 1)
    restriction_tag = "".join(str(p) for p in par_rest)
    return (
        f"estimation_{price_iv}_{model}{restriction_tag}"
        f"_k{flavor}_s{iteration}_m{market}{version}.mat"
    )

def check_file_exists(filename, output_dir):
    """Return True if ``output_dir/filename`` exists (file or directory)."""
    return os.path.exists(os.path.join(output_dir, filename))

def submit_job(market_list, price_iv_list, model_list, par_rest_list, seed_list, version_list, batch_number):
    """Write a SLURM batch script for one batch of estimations and submit it.

    The script ``missing_batch_<batch_number>.sh`` is written inside the
    ``code`` directory (relative to the current working directory) and is
    submitted from there with ``sbatch --parsable``. It requests one task per
    estimation and runs one ``srun matlab ... EstimateUnified(...)`` per
    estimation.

    All ``*_list`` arguments are parallel sequences of equal length: element
    ``i`` of each describes the ``i``-th estimation in the batch.

    Args:
        market_list: Market size for each estimation.
        price_iv_list: Price-instrument label for each estimation. Its length
            sets ``--ntasks``.
        model_list: Model label for each estimation.
        par_rest_list: Parameter-restriction list for each estimation. Its
            Python ``str`` form (e.g. ``[2, 3]`` or ``[]``) is pasted into the
            MATLAB call as an array literal.
        seed_list: Simulation seed for each estimation.
        version_list: Version suffix for each estimation.
        batch_number: Identifies the batch in the job name, log filenames and
            script filename.

    The original working directory is restored even if writing the script or
    submitting it raises.
    """
    data_path = shared_simulations
    output_path = shared_estimations
    user = getpass.getuser()
    job = "missing_batch_%s" % batch_number

    slurm_script = f"""#!/bin/bash
#SBATCH -p gentzkow,hns
#SBATCH --job-name=missing_batch_{batch_number}
#SBATCH --output=missing_{batch_number}.out
#SBATCH --error=missing_{batch_number}.err
#SBATCH --nodes=1
#SBATCH --ntasks={len(price_iv_list)}
#SBATCH --mem=32g
#SBATCH --time=15:00:00
#SBATCH --qos=normal

ml purge
ml system git
ml system git-lfs
ml curl/7.81.0
ml system rclone
ml load matlab
source /home/users/{user}/miniconda3/bin/activate blp-instruments
"""

    start_dir = os.getcwd()
    # The generated script does `cd EstimateUnified`. By default SLURM starts
    # the job in the submission directory, so both writing and submission
    # happen inside `code`.
    os.chdir("code")
    try:
        script_filename = clean_filename(f"missing_batch_{batch_number}.sh")
        with open(script_filename, "w") as f:
            f.write(slurm_script)
            for price_iv, model, par_rest, seed, market, version in zip(
                price_iv_list, model_list, par_rest_list, seed_list, market_list, version_list
            ):
                f.write(f"""
cd EstimateUnified 
srun  matlab -nodisplay -nosplash -nodesktop \\
    -r "EstimateUnified('{price_iv}', '{model}', {par_rest}, {seed}, '{data_path}', '{output_path}', {market}, '{version}'); exit;"\n
cd -
""")

        cmd = "sbatch --parsable %s" % script_filename
        print("Submitting job for %s" % job)

        status, output = subprocess.getstatusoutput(cmd)
        if status == 0:
            # With --parsable, sbatch prints the job id on success.
            print("Job id for %s is %s" % (job, output))
        else:
            # Report the captured result. Re-running sbatch here just to show
            # the error could submit the same job a second time.
            print("Error submitting job for %s" % job)
            print((status, output))
    finally:
        os.chdir(start_dir)

def clean_filename(filename):
    """Strip ``(``, ``)``, ``,`` and ``'`` from *filename* and turn spaces into ``_``."""
    table = str.maketrans({"(": None, ")": None, ",": None, "'": None, " ": "_"})
    return filename.translate(table)

def main(batch_size=1):
    """Report missing estimation outputs and optionally submit jobs for them.

    Every parameter set from ``generate_parameter_sets()`` whose expected
    output file (per ``generate_filenames``) is absent from
    ``shared_estimations`` is printed and written to
    ``output/missing_estimation.csv``. When the ``missing_run`` config flag is
    truthy, the missing sets are passed to ``submit_job`` in batches of
    ``batch_size``. Each batch becomes one SLURM job, numbered from 1.

    Args:
        batch_size: Number of parameter sets per submitted job. Defaults to 1,
            i.e. one job per missing output.
    """
    # Collect parameter sets whose output file does not exist yet.
    missing_files = []
    for parameters in generate_parameter_sets():
        filename = generate_filenames(parameters)
        if not check_file_exists(filename, shared_estimations):
            missing_files.append(parameters)
            print(filename)

    print(f"There are {len(missing_files)} missing files.")

    csv_filename = 'output/missing_estimation.csv'
    # NOTE(review): the last column is named 'version_list', unlike the
    # singular names of the others. It is left unchanged in case downstream
    # readers rely on it.
    header = ['price_iv', 'model', 'suffix', 'seed', 'market', 'par_rest', 'version_list']
    # Create the output directory on a fresh checkout instead of crashing.
    os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(missing_files)

    if not missing_run:
        return

    batch_starts = range(0, len(missing_files), batch_size)
    for batch_number, start in enumerate(batch_starts, start=1):
        batch = missing_files[start:start + batch_size]
        # Transpose the batch of tuples into one parallel sequence per field.
        (price_iv_list, model_list, _suffix_list, seed_list,
         market_list, par_rest_list, version_list) = zip(*batch)

        submit_job(market_list, price_iv_list, model_list, par_rest_list, seed_list, version_list, batch_number)

if __name__ == "__main__":
    main()
